{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-10-30T12:12:29.874400Z",
     "start_time": "2020-10-30T12:12:27.918175Z"
    }
   },
   "outputs": [],
   "source": [
    "import sklearn\n",
    "from sklearn.model_selection import train_test_split\n",
    "import seaborn as sns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-10-30T12:13:10.015884Z",
     "start_time": "2020-10-30T12:13:10.000646Z"
    }
   },
   "outputs": [],
   "source": [
    "# 1. data pre-processing\n",
    "# Load iris, separate the four measurement columns from the species label,\n",
    "# then hold out a test set (seeded so the split is reproducible).\n",
    "iris = sns.load_dataset('iris')\n",
    "y_iris = iris['species']\n",
    "X_iris = iris.drop(columns='species')\n",
    "\n",
    "Xtrain, Xtest, ytrain, ytest = train_test_split(\n",
    "    X_iris, y_iris, random_state=1\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-10-30T12:13:17.716949Z",
     "start_time": "2020-10-30T12:13:17.700958Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>sepal_length</th>\n",
       "      <th>sepal_width</th>\n",
       "      <th>petal_length</th>\n",
       "      <th>petal_width</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>54</th>\n",
       "      <td>6.5</td>\n",
       "      <td>2.8</td>\n",
       "      <td>4.6</td>\n",
       "      <td>1.5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>108</th>\n",
       "      <td>6.7</td>\n",
       "      <td>2.5</td>\n",
       "      <td>5.8</td>\n",
       "      <td>1.8</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>112</th>\n",
       "      <td>6.8</td>\n",
       "      <td>3.0</td>\n",
       "      <td>5.5</td>\n",
       "      <td>2.1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17</th>\n",
       "      <td>5.1</td>\n",
       "      <td>3.5</td>\n",
       "      <td>1.4</td>\n",
       "      <td>0.3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>119</th>\n",
       "      <td>6.0</td>\n",
       "      <td>2.2</td>\n",
       "      <td>5.0</td>\n",
       "      <td>1.5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>133</th>\n",
       "      <td>6.3</td>\n",
       "      <td>2.8</td>\n",
       "      <td>5.1</td>\n",
       "      <td>1.5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>137</th>\n",
       "      <td>6.4</td>\n",
       "      <td>3.1</td>\n",
       "      <td>5.5</td>\n",
       "      <td>1.8</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>72</th>\n",
       "      <td>6.3</td>\n",
       "      <td>2.5</td>\n",
       "      <td>4.9</td>\n",
       "      <td>1.5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>140</th>\n",
       "      <td>6.7</td>\n",
       "      <td>3.1</td>\n",
       "      <td>5.6</td>\n",
       "      <td>2.4</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>37</th>\n",
       "      <td>4.9</td>\n",
       "      <td>3.6</td>\n",
       "      <td>1.4</td>\n",
       "      <td>0.1</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>112 rows × 4 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "     sepal_length  sepal_width  petal_length  petal_width\n",
       "54            6.5          2.8           4.6          1.5\n",
       "108           6.7          2.5           5.8          1.8\n",
       "112           6.8          3.0           5.5          2.1\n",
       "17            5.1          3.5           1.4          0.3\n",
       "119           6.0          2.2           5.0          1.5\n",
       "..            ...          ...           ...          ...\n",
       "133           6.3          2.8           5.1          1.5\n",
       "137           6.4          3.1           5.5          1.8\n",
       "72            6.3          2.5           4.9          1.5\n",
       "140           6.7          3.1           5.6          2.4\n",
       "37            4.9          3.6           1.4          0.1\n",
       "\n",
       "[112 rows x 4 columns]"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Training features returned by train_test_split\n",
    "Xtrain"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-10-30T12:13:24.555858Z",
     "start_time": "2020-10-30T12:13:24.550866Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "54     versicolor\n",
       "108     virginica\n",
       "112     virginica\n",
       "17         setosa\n",
       "119     virginica\n",
       "          ...    \n",
       "133     virginica\n",
       "137     virginica\n",
       "72     versicolor\n",
       "140     virginica\n",
       "37         setosa\n",
       "Name: species, Length: 112, dtype: object"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Training labels, row-aligned with Xtrain\n",
    "ytrain"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-10-30T12:13:40.317537Z",
     "start_time": "2020-10-30T12:13:40.312719Z"
    }
   },
   "outputs": [],
   "source": [
    "from sklearn.naive_bayes import GaussianNB   # 2. choose model class\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-10-30T12:13:48.262705Z",
     "start_time": "2020-10-30T12:13:48.258633Z"
    }
   },
   "outputs": [],
   "source": [
    "# 3. instantiate model (no arguments, so library defaults apply)\n",
    "model = GaussianNB()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-10-30T12:13:55.170320Z",
     "start_time": "2020-10-30T12:13:55.160711Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "GaussianNB(priors=None, var_smoothing=1e-09)"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 4. fit model to data\n",
    "# fit() is the cell's last expression, so the fitted estimator's repr is shown.\n",
    "model.fit(Xtrain, ytrain)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-10-30T12:14:07.806271Z",
     "start_time": "2020-10-30T12:14:07.801996Z"
    }
   },
   "outputs": [],
   "source": [
    "# 5. predict on new data (the held-out test features)\n",
    "y_model = model.predict(Xtest)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-10-30T12:14:15.524903Z",
     "start_time": "2020-10-30T12:14:15.521376Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array(['setosa', 'versicolor', 'versicolor', 'setosa', 'virginica',\n",
       "       'versicolor', 'virginica', 'setosa', 'setosa', 'virginica',\n",
       "       'versicolor', 'setosa', 'virginica', 'versicolor', 'versicolor',\n",
       "       'setosa', 'versicolor', 'versicolor', 'setosa', 'setosa',\n",
       "       'versicolor', 'versicolor', 'virginica', 'setosa', 'virginica',\n",
       "       'versicolor', 'setosa', 'setosa', 'versicolor', 'virginica',\n",
       "       'versicolor', 'virginica', 'versicolor', 'virginica', 'virginica',\n",
       "       'setosa', 'versicolor', 'setosa'], dtype='<U10')"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Predicted species for each row of Xtest\n",
    "y_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-10-30T12:14:34.489423Z",
     "start_time": "2020-10-30T12:14:34.482380Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "setosa versicolor versicolor setosa virginica versicolor virginica setosa setosa virginica versicolor setosa virginica versicolor versicolor setosa versicolor versicolor setosa setosa versicolor versicolor versicolor setosa virginica versicolor setosa setosa versicolor virginica versicolor virginica versicolor virginica virginica setosa versicolor setosa\n"
     ]
    }
   ],
   "source": [
    "# True species of the held-out rows, space-separated, to compare against y_model\n",
    "print(*ytest.tolist())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-10-30T12:14:52.199570Z",
     "start_time": "2020-10-30T12:14:52.193626Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.9736842105263158"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sklearn.metrics import accuracy_score\n",
    "\n",
    "# Share of held-out labels that the model predicted correctly\n",
    "accuracy_score(y_true=ytest, y_pred=y_model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": false,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
